#This script plots the VDEM media critical scores and the cosine similarity scores
#The cosine similarity scores are calculated for the leader word analyses
#It produces Figure 2 in the paper
library(boot)
library(dplyr)
library(tidyr)
library(lubridate)
library(purrr)
library(readr)
library(ggplot2)
library(ggthemes)
library(cowplot)

# Load the weekly cosine-similarity series ----

# Site slugs; each one names a sub-directory of data/output/cos_sims/
countries <- c("djazairess", "maghress", "masress", "sauress", "turess")

# One result file per sample size. The version tag in the file name is the
# sample size written out in plain digits followed by "30k". Each size is
# formatted on its own: calling format() on the whole vector would pad the
# shorter numbers with spaces to a common width.
sample_sizes <- c(1e4, 5e4, 1e5, 5e5, 1e6, 1.5e6)
formatted_sample_sizes <- vapply(
  sample_sizes,
  format,
  character(1),
  scientific = FALSE
)
versions <- paste0(formatted_sample_sizes, "30k")

# Read the .rds file for one country/version pair and tidy it: `group` becomes
# a Date called `yearwk`, `val` becomes `cos_sim`, rows are sorted by date, and
# the country slug and version tag are recorded as columns.
read_cos_sims <- function(country, version) {
  rds_path <- paste0("data/output/cos_sims/", country, "/", "cos_simsdf_all", version, ".rds")
  readRDS(rds_path) %>%
    mutate(group = as.Date(group)) %>%
    arrange(group) %>%
    rename(yearwk = group, cos_sim = val) %>%
    mutate(country = country, version = version)
}

# Stack every version x country combination (all countries for the first
# version, then all countries for the next, and so on)
all_data <- versions %>%
  lapply(function(version) {
    countries %>%
      lapply(read_cos_sims, version = version) %>%
      bind_rows()
  }) %>%
  bind_rows()

# Keep a flat copy of the combined table
write_csv(all_data, "data/vdem/cos_sims_all.csv")

# Country name for each site slug
country_map <- c(
  djazairess = "Algeria",
  maghress = "Morocco",
  masress = "Egypt",
  sauress = "Saudi Arabia",
  turess = "Tunisia"
)

# Attach the country name and the calendar year of every observation; the year
# is taken from the first four characters of the date's text form
all_data <- all_data %>%
  mutate(
    country_name = country_map[country],
    year = as.numeric(substr(yearwk, 1, 4))
  )

# Put cos_sim on a 0-4 scale with a min-max transform.
# Missing values are ignored when finding the observed extremes: without
# na.rm = TRUE a single NA in cos_sim makes min()/max() return NA, which would
# turn every rescaled value into NA.
min_value <- min(all_data$cos_sim, na.rm = TRUE)
max_value <- max(all_data$cos_sim, na.rm = TRUE)

# Bounds of the target scale
new_min <- 0
new_max <- 4

# Map [min_value, max_value] linearly onto [new_min, new_max]; rows whose
# cos_sim is NA stay NA
all_data$rescaled_cos_sim <- ((all_data$cos_sim - min_value) / (max_value - min_value)) * (new_max - new_min) + new_min

# Per country-year standard deviation of the rescaled scores, together with the
# number of non-missing observations it is based on (needed for the standard
# error below)
result_df <- all_data %>%
  group_by(country_name, year) %>%
  summarize(
    std_dev_cos_sim = sd(rescaled_cos_sim, na.rm = TRUE),
    n_obs = sum(!is.na(rescaled_cos_sim))
  ) %>%
  ungroup()

# Per country-year mean of the rescaled scores
result_df2 <- all_data %>%
  group_by(country_name, year) %>%
  summarize(mean_cos_sim = mean(rescaled_cos_sim, na.rm = TRUE)) %>%
  ungroup()

merged_df <- merge(result_df, result_df2, by = c("country_name", "year"))

# 95% normal-approximation confidence interval around each country-year mean.
# The standard error divides the group's standard deviation by the square root
# of the number of observations in that same group, not by the number of
# countries in the table.
# NOTE(review): all_data stacks several versions of the same weeks, so the
# observations within a group may not be independent - confirm n_obs is the
# sample size you want here.
std_error <- merged_df$std_dev_cos_sim / sqrt(merged_df$n_obs)

merged_df$lower_ci <- merged_df$mean_cos_sim - 1.96 * std_error  # lower bound
merged_df$upper_ci <- merged_df$mean_cos_sim + 1.96 * std_error  # upper bound

# Bring in the V-Dem media-critical series ----

# Read one country's V-Dem CSV and standardise its column names.
#   site:    sub-directory of data/vdem/ that holds the country's
#            vdem_media_critical.csv
#   country: country name exactly as it appears in the CSV header; the file
#            must contain the columns "<country>", "<country> CI (Low)" and
#            "<country> CI (High)" (rename() stops with an error otherwise)
# Returns the table with those three columns renamed to v2mecrit,
# v2mecrit_ci_low and v2mecrit_ci_high, and a country_name column appended.
read_vdem_media_critical <- function(site, country) {
  # new name = old name
  column_lookup <- c(
    v2mecrit = country,
    v2mecrit_ci_low = paste0(country, " CI (Low)"),
    v2mecrit_ci_high = paste0(country, " CI (High)")
  )

  vdem <- read_csv(paste0("data/vdem/", site, "/vdem_media_critical.csv")) %>%
    rename(all_of(column_lookup))

  vdem$country_name <- country
  vdem
}

algeria_vdem_media_critical <- read_vdem_media_critical("djazairess", "Algeria")
morocco_vdem_media_critical <- read_vdem_media_critical("maghress", "Morocco")
egypt_vdem_media_critical <- read_vdem_media_critical("masress", "Egypt")
saudi_vdem_media_critical <- read_vdem_media_critical("sauress", "Saudi Arabia")
tunisia_vdem_media_critical <- read_vdem_media_critical("turess", "Tunisia")

# Stack the five countries into one long table
vdem_v2mecrit <- rbind(tunisia_vdem_media_critical, saudi_vdem_media_critical,
                       egypt_vdem_media_critical, morocco_vdem_media_critical,
                       algeria_vdem_media_critical)

# Keep V-Dem observations up to and including 2020, and give the year column
# the same lower-case name as in the cosine-similarity summary
vdem_v2mecrit <- vdem_v2mecrit %>%
  filter(Year <= 2020) %>%
  rename(year = Year)

# Pair each country-year's cosine-similarity summary with its V-Dem score;
# only country-years present in both tables are kept
vdem_medcrit_comparison <- merge(
  merged_df,
  vdem_v2mecrit,
  by = c("country_name", "year")
)

# Pearson's r between the two yearly series, one value per country
correlation_data <- vdem_medcrit_comparison %>%
  group_by(country_name) %>%
  summarise(
    pearson_R = cor(mean_cos_sim, v2mecrit, use = "pairwise.complete.obs")
  )

# Fix the left-to-right order of the country panels
vdem_v2mecrit$country_name <- factor(
  vdem_v2mecrit$country_name,
  levels = c("Egypt", "Tunisia", "Algeria", "Morocco", "Saudi Arabia")
)

# Text sizes, panel frame, background and grid lines layered on top of the
# Tufte base theme
g1_theme_tweaks <- theme(
  legend.position = "right",
  legend.title = element_text(size = 20),
  legend.text = element_text(size = 15),
  axis.title.x = element_text(size = 15),
  axis.title.y = element_text(size = 15),
  axis.text.x = element_text(size = 20),
  axis.text.y = element_text(size = 20),
  strip.text = element_text(size = 20),
  panel.border = element_rect(colour = "black", fill = NA, size = 1),
  panel.grid.major = element_line(size = 0.1, linetype = "solid"),
  panel.grid.minor = element_line(size = 0.1, linetype = "solid"),
  plot.background = element_rect(fill = "white", colour = NA)
)

# Yearly V-Dem score per country (points) with its confidence band (ribbon),
# one facet per country; the view is zoomed to 2008-2020 and scores 0-3
g1 <- ggplot(vdem_v2mecrit, aes(x = year, y = v2mecrit)) +
  geom_point(alpha = .25) +
  geom_ribbon(aes(ymin = v2mecrit_ci_low, ymax = v2mecrit_ci_high), alpha = .2) +
  facet_wrap( ~ country_name, ncol = 5) +
  scale_x_continuous(breaks = c(2010, 2015, 2020)) +
  scale_y_continuous(labels = scales::number_format(accuracy = 0.1)) +
  coord_cartesian(ylim = c(0.0, 3.0), xlim = c(2008, 2020)) +
  labs(x = "Year", y = "V-Dem score") +
  theme_tufte(base_family = "Helvetica") +
  g1_theme_tweaks

# Prepare the weekly cosine-similarity table for the second panel.
# all_data was already built at the top of the script from the same
# country/version .rds files and has only gained columns since then, so it is
# reused here instead of reading every file from disk a second time. The
# additional year and rescaled_cos_sim columns it carries are not used below.
stopifnot(
  is.data.frame(all_data),
  all(c("yearwk", "cos_sim", "version", "country_name") %in% names(all_data))
)

# Order of versions, from the smallest to the largest sample size
ordered_versions <- c("1000030k", "5000030k", "10000030k", "50000030k", "100000030k", "150000030k")
# Convert 'version' to a factor with that order
all_data$version <- factor(all_data$version, levels = ordered_versions)
# Fix the left-to-right order of the country panels
all_data$country_name <- factor(
  all_data$country_name,
  levels = c("Egypt", "Tunisia", "Algeria", "Morocco", "Saudi Arabia")
)


# Weekly cosine similarity per country for the "150000030k" version only:
# raw points plus a loess trend line, one facet per country
g2 <- all_data %>%
  filter(version %in% c("150000030k")) %>%
  ggplot(aes(x = yearwk, y = cos_sim)) +
  geom_point(alpha = .25) +
  # Black loess trend; its confidence band is filled white
  geom_smooth(
    method = "loess",
    size = 1,
    span = .5,
    fill = "white",
    col = "black"
  ) +
  theme_tufte(base_family = "Helvetica") +
  labs(x = "Year-week",
       y = "Cosine similarity, leaders : opposition index") + # y-axis label
  # NOTE(review): ylim() sets scale limits, so observations outside
  # [-0.2, 0.2] are removed before the loess fit rather than just hidden;
  # coord_cartesian(ylim = ...) would only zoom. Confirm this is intended.
  ylim(-.2, 0.2) +
  theme(
    legend.position = "right",
    axis.text.x = element_text(size = 20),
    axis.text.y = element_text(size = 20),
    axis.title.x = element_text(size = 15),
    axis.title.y = element_text(size = 15),
    legend.text = element_text(size = 15),
    legend.title = element_text(size = 20),
    panel.border = element_rect(
      colour = "black",
      fill = NA,
      size = 1
    ),
    plot.background = element_rect(fill = "white", colour = NA),
    panel.grid.major = element_line(size = 0.1, linetype = "solid"),
    panel.grid.minor = element_line(size = 0.1, linetype = "solid"),
    strip.text = element_text(size = 20)
  ) +
  facet_wrap( ~ country_name, ncol = 5)

# Write the two-panel figure to plots/fig2.png: g1 on top, g2 below, with
# automatic panel labels.
# The grid is print()ed explicitly. A bare top-level expression is only
# auto-printed in an interactive session or under Rscript, so running this
# file through source() would otherwise close the device without drawing the
# figure.
png(
  "plots/fig2.png",
  width = 750,
  height = 300,
  units = 'mm',
  res = 300
)
print(plot_grid(g1, g2, ncol = 1, nrow = 2, labels = "AUTO"))
dev.off()